{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Various torch packages\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "# torchvision\n",
    "from torchvision import datasets, transforms\n",
    "\n",
    "# ------------------------\n",
    "# get up one directory \n",
    "import sys, os\n",
    "sys.path.append(os.path.abspath('../'))\n",
    "# ------------------------\n",
    "\n",
    "# custom packages\n",
    "import models.aux_funs as maf\n",
    "import optimizers as op\n",
    "import regularizers as reg\n",
    "import train\n",
    "import math\n",
    "import utils.configuration as cf\n",
    "import utils.datasets as ud\n",
    "from utils.datasets import get_data_set, GaussianSmoothing\n",
    "from models.fully_connected import fully_connected\n",
    "from models.resnet import ResNet18"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# -----------------------------------------------------------------------------------\n",
    "# Fix random seed\n",
    "# -----------------------------------------------------------------------------------\n",
    "random_seed = 2\n",
    "cf.seed_torch(random_seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# -----------------------------------------------------------------------------------\n",
    "# Parameters\n",
    "# -----------------------------------------------------------------------------------\n",
    "conf_args = {#\n",
    "    # data specification\n",
    "    'data_file':\"../../Data\", 'train_split':0.95, 'data_set':\"CIFAR10\", 'download':False,\n",
    "    # cuda\n",
    "    'use_cuda':True, 'num_workers':4, 'cuda_device':0, 'pin_memory':True,\n",
    "    #\n",
    "    'epochs':100,\n",
    "    # optimizer\n",
    "    'delta':1.0, 'lr':0.001, 'lamda_0':0.1, 'lamda_1':0.0, 'optim':\"AdaBreg\", 'conv_group':True,\n",
    "    'beta':0.0,\n",
    "    # initialization\n",
    "    'sparse_init':0.01, 'r':[1,5,1],\n",
    "    # misc\n",
    "    'random_seed':random_seed, 'eval_acc':True,\n",
    "}\n",
    "conf = cf.Conf(**conf_args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "ename": "TypeError",
     "evalue": "ResNet18() got an unexpected keyword argument 'mean'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[10], line 7\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[39m# -----------------------------------------------------------------------------------\u001b[39;00m\n\u001b[1;32m      2\u001b[0m \u001b[39m# define the model and an instance of the best model class\u001b[39;00m\n\u001b[1;32m      3\u001b[0m \u001b[39m# -----------------------------------------------------------------------------------\u001b[39;00m\n\u001b[1;32m      4\u001b[0m model_kwargs \u001b[39m=\u001b[39m {\u001b[39m'\u001b[39m\u001b[39mmean\u001b[39m\u001b[39m'\u001b[39m:conf\u001b[39m.\u001b[39mdata_set_mean, \u001b[39m'\u001b[39m\u001b[39mstd\u001b[39m\u001b[39m'\u001b[39m:conf\u001b[39m.\u001b[39mdata_set_std, \u001b[39m'\u001b[39m\u001b[39mim_size\u001b[39m\u001b[39m'\u001b[39m:conf\u001b[39m.\u001b[39mim_shape[\u001b[39m1\u001b[39m],\n\u001b[1;32m      5\u001b[0m                 \u001b[39m'\u001b[39m\u001b[39mim_channels\u001b[39m\u001b[39m'\u001b[39m:conf\u001b[39m.\u001b[39mim_shape[\u001b[39m0\u001b[39m]}    \n\u001b[0;32m----> 7\u001b[0m model \u001b[39m=\u001b[39m ResNet18(mean \u001b[39m=\u001b[39;49m conf\u001b[39m.\u001b[39;49mdata_set_mean, std \u001b[39m=\u001b[39;49m conf\u001b[39m.\u001b[39;49mdata_set_std)\n\u001b[1;32m      8\u001b[0m best_model \u001b[39m=\u001b[39m train\u001b[39m.\u001b[39mbest_model(ResNet18(mean \u001b[39m=\u001b[39m conf\u001b[39m.\u001b[39mdata_set_mean, std \u001b[39m=\u001b[39m conf\u001b[39m.\u001b[39mdata_set_std)\u001b[39m.\u001b[39mto(conf\u001b[39m.\u001b[39mdevice))\n\u001b[1;32m     10\u001b[0m \u001b[39m# sparsify\u001b[39;00m\n",
      "\u001b[0;31mTypeError\u001b[0m: ResNet18() got an unexpected keyword argument 'mean'"
     ]
    }
   ],
   "source": [
    "# -----------------------------------------------------------------------------------\n",
    "# define the model and an instance of the best model class\n",
    "# -----------------------------------------------------------------------------------\n",
    "model_kwargs = {'mean':conf.data_set_mean, 'std':conf.data_set_std, 'im_size':conf.im_shape[1],\n",
    "                'im_channels':conf.im_shape[0]}    \n",
    "\n",
    "model = ResNet18(mean = conf.data_set_mean, std = conf.data_set_std)\n",
    "best_model = train.best_model(ResNet18(mean = conf.data_set_mean, std = conf.data_set_std).to(conf.device))\n",
    "    \n",
    "# sparsify\n",
    "maf.sparse_bias_uniform_(model, 0,conf.r[0])\n",
    "maf.sparse_weight_normal_(model, conf.r[1])\n",
    "maf.sparsify_(model, conf.sparse_init, conv_group = conf.conv_group)\n",
    "model = model.to(conf.device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# -----------------------------------------------------------------------------------\n",
    "# define the model and an instance of the best model class\n",
    "# -----------------------------------------------------------------------------------\n",
    "model_kwargs = {'mean':conf.data_set_mean, 'std':conf.data_set_std}    \n",
    "\n",
    "def init_weights(conf, model):\n",
    "    # sparsify\n",
    "    maf.sparse_bias_uniform_(model, 0, conf.r[0])\n",
    "    maf.sparse_weight_normal_(model, conf.r[1])\n",
    "    maf.sparse_weight_normal_(model, conf.r[1],ltype=nn.Conv2d)\n",
    "    maf.sparsify_(model, conf.sparse_init, ltype = nn.Conv2d, conv_group = conf.conv_group)\n",
    "    model = model.to(conf.device)\n",
    "    \n",
    "    return model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Optimizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# -----------------------------------------------------------------------------------\n",
    "# Optimizer\n",
    "# -----------------------------------------------------------------------------------\n",
    "def init_opt(conf, model):\n",
    "    # Get access to different model parameters\n",
    "    weights_conv = maf.get_weights_conv(model)\n",
    "    weights_linear = maf.get_weights_linear(model)\n",
    "    weights_batch = maf.get_weights_batch(model)\n",
    "    biases = maf.get_bias(model)\n",
    "\n",
    "   \n",
    "    if conf.conv_group:\n",
    "        reg_0 = reg.reg_l1_l2_conv(conf.lamda_0)\n",
    "    else:\n",
    "        reg_0 = reg.reg_l1(conf.lamda)\n",
    "        \n",
    "    reg_1 = reg.reg_l1(conf.lamda_1)\n",
    "        \n",
    "        \n",
    "    if conf.optim == \"SGD\":\n",
    "        opt = torch.optim.SGD(model.parameters(), lr=lr, momentum=beta)\n",
    "    elif conf.optim == \"LinBreg\":\n",
    "        opt = op.LinBreg([{'params': weights_conv, 'lr' : conf.lr, 'reg' : reg_0, 'momentum':conf.beta},\n",
    "                          {'params': weights_linear, 'lr' : conf.lr, 'reg' : reg_1, 'momentum':conf.beta},\n",
    "                          {'params': weights_batch, 'lr' : conf.lr, 'momentum':beta},\n",
    "                          {'params': biases, 'lr': conf.lr, 'momentum':beta}])\n",
    "    elif conf.optim == \"AdaBreg\":\n",
    "        opt = op.AdaBreg([{'params': weights_conv, 'lr' : conf.lr, 'reg' : reg_0},\n",
    "                           {'params': weights_linear, 'lr' : conf.lr, 'reg' : reg_1},\n",
    "                           {'params': weights_batch, 'lr' : conf.lr, 'momentum':conf.beta},\n",
    "                           {'params': biases, 'lr': conf.lr}])\n",
    "    elif conf.optim == \"L1SGD\":\n",
    "        def weight_reg(model):\n",
    "            reg_0 =  reg.reg_l1(lamda=conf.lamda_0)\n",
    "            \n",
    "            loss_1 = 0\n",
    "            for w in maf.get_weights_conv(model):\n",
    "                loss_1 += reg_1(w)\n",
    "                \n",
    "            loss_0 = 0\n",
    "            for w in maf.get_weights_linear(model):\n",
    "                loss_0 += reg_0(w)\n",
    "                \n",
    "            return loss_0 + loss_1\n",
    "        \n",
    "        conf.weight_reg = weight_reg\n",
    "        \n",
    "        opt = torch.optim.SGD(model.parameters(), lr=conf.lr, momentum=conf.beta)\n",
    "    else:\n",
    "        raise ValueError(\"Unknown Optimizer specified\")\n",
    "\n",
    "    # learning rate scheduler\n",
    "    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, factor=0.5, patience=5,threshold=0.01)\n",
    "    \n",
    "    return opt, scheduler"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_loader, valid_loader, test_loader = ud.get_data_set(conf)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# History and Runs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize history\n",
    "tracked = ['loss', 'node_sparse']\n",
    "train_hist = {key: [] for key in tracked}\n",
    "val_hist = {key: [] for key in tracked}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# -----------------------------------------------------------------------------------\n",
    "# Reinit weigts and the corresponding optimizer\n",
    "# -----------------------------------------------------------------------------------\n",
    "model = init_weights(conf, model)\n",
    "opt, scheduler = init_opt(conf, model)\n",
    "\n",
    "# -----------------------------------------------------------------------------------\n",
    "# train the model\n",
    "# -----------------------------------------------------------------------------------\n",
    "for epoch in range(conf.epochs):\n",
    "    print(25*\"<>\")\n",
    "    print(50*\"|\")\n",
    "    print(25*\"<>\")\n",
    "    print('Epoch:', epoch)\n",
    "\n",
    "    # ------------------------------------------------------------------------\n",
    "    # train step, log the accuracy and loss\n",
    "    # ------------------------------------------------------------------------\n",
    "    train_data = train.train_step(conf, model, opt, train_loader)\n",
    "\n",
    "    # update history\n",
    "    for key in tracked:\n",
    "        if key in train_data:\n",
    "            var_list = train_hist.setdefault(key, [])\n",
    "            var_list.append(train_data[key])        \n",
    "\n",
    "    # ------------------------------------------------------------------------\n",
    "    # validation step\n",
    "    val_data = train.validation_step(conf, model, opt, valid_loader)\n",
    "\n",
    "    # update history\n",
    "    for key in tracked:\n",
    "        if key in val_data:\n",
    "            var = val_data[key]\n",
    "            if isinstance(var, list):\n",
    "                for i, var_loc in enumerate(var):\n",
    "                    key_loc = key+\"_\" + str(i)\n",
    "                    var_list = val_hist.setdefault(key_loc, [])\n",
    "                    val_hist[key_loc].append(var_loc)\n",
    "\n",
    "\n",
    "    scheduler.step(train_data['loss'])\n",
    "    print(\"Learning rate:\",opt.param_groups[0]['lr'])\n",
    "    best_model(train_data['acc'], val_data['acc'], model=model)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
